{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Modeling Iris Data Using Deep Neural Network in PyTorch\n",
    "\n",
    "### Anirban Sinha, PhD candidate, Stony Brook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch version: 1.0.0\n",
      "Numpy version: 1.15.4\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Report the versions of the libraries this notebook runs against\n",
    "for template, module in ((\"PyTorch version: {}\", torch), (\"Numpy version: {}\", np)):\n",
    "    print(template.format(module.__version__))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean of features: [5.84333333 3.054      3.75866667 1.19866667]\n",
      "Std of features: [0.82530129 0.43214658 1.75852918 0.76061262]\n"
     ]
    }
   ],
   "source": [
    "# Load Iris Dataset\n",
    "# Directory holding the dataset files. The default is the original author's\n",
    "# location; set the IRIS_DATA_DIR environment variable to run elsewhere.\n",
    "FILE_PATH = os.environ.get(\"IRIS_DATA_DIR\", \"/home/nobug-ros/dev_ml/PyTorch-for-Iris-Dataset/\")\n",
    "MAIN_FILE_NAME = \"iris_dataset.txt\"\n",
    "TRAIN_FILE_NAME = \"iris_train_dataset.txt\"\n",
    "TEST_FILE_NAME = \"iris_test_dataset.txt\"\n",
    "NUM_FEATURE = 4  # the first NUM_FEATURE columns of each row are the features\n",
    "NUM_CLASS = 3\n",
    "num_epoch = 20\n",
    "\n",
    "# os.path.join gives the right path whether or not FILE_PATH ends with a separator\n",
    "data = np.loadtxt(os.path.join(FILE_PATH, MAIN_FILE_NAME), delimiter=\",\")\n",
    "# Per-feature statistics of the main file; the standardization cell applies\n",
    "# them to both the train and the test split.\n",
    "mean_data = np.mean(data[:, :NUM_FEATURE], axis=0)\n",
    "std_data = np.std(data[:, :NUM_FEATURE], axis=0)\n",
    "\n",
    "train_data = np.loadtxt(os.path.join(FILE_PATH, TRAIN_FILE_NAME), delimiter=\",\")\n",
    "test_data = np.loadtxt(os.path.join(FILE_PATH, TEST_FILE_NAME), delimiter=\",\")\n",
    "\n",
    "print(\"Mean of features: {}\".format(mean_data))\n",
    "print(\"Std of features: {}\".format(std_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardize (pre-process) the feature columns of the train and test data.\n",
    "# NOTE: the arrays are updated in place, so run this cell once per load --\n",
    "# re-running it would standardize values that are already standardized.\n",
    "for split in (train_data, test_data):\n",
    "    # Broadcasting applies the per-feature mean/std to every row at once;\n",
    "    # columns from NUM_FEATURE onward are left untouched.\n",
    "    split[:, :NUM_FEATURE] = (\n",
    "        split[:, :NUM_FEATURE] - mean_data[:NUM_FEATURE]\n",
    "    ) / std_data[:NUM_FEATURE]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the standardized NumPy arrays into float32 PyTorch tensors\n",
    "train_data = torch.as_tensor(train_data, dtype=torch.float32)\n",
    "test_data = torch.as_tensor(test_data, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start building the Neural Network using PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the device to compute on (CPU or GPU).\n",
    "# NOTE(review): `device` is not referenced by the other cells shown in this\n",
    "# notebook (nothing is moved with .to(device)) -- confirm whether it is needed.\n",
    "device = torch.device('cpu')\n",
    "# For a GPU use torch.device('cuda'); 'gpu' is not a device type PyTorch accepts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network size: NUM_FEATURE inputs, H hidden units, one output per class\n",
    "D_in, H, D_out = NUM_FEATURE, 8, NUM_CLASS\n",
    "\n",
    "# Seed PyTorch's RNG so the initial layer weights are the same on every run\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Use the nn package to define our model and loss function.\n",
    "model = torch.nn.Sequential(\n",
    "          torch.nn.Linear(D_in, H),\n",
    "          torch.nn.ReLU(),\n",
    "          torch.nn.Linear(H, D_out),\n",
    "          # dim=-1 normalizes over the class axis: identical to dim=0 for a single\n",
    "          # 1-D sample, and still correct for a batch of shape (N, D_out).\n",
    "          torch.nn.Softmax(dim=-1),\n",
    "        )\n",
    "# Mean squared error between the softmax output and the target vector\n",
    "loss_fn = torch.nn.MSELoss(reduction='mean')\n",
    "\n",
    "# We use Adam; the optim package contains many other optimization algorithms.\n",
    "# The first argument to the Adam constructor tells the optimizer which Tensors it should update.\n",
    "learning_rate = 1e-4\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEKCAYAAADjDHn2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3XuclnP+x/HXp6M0DkVmdUA5RBZpppKIkd0KK35bK4c2h5Qftay1aBGy7BLr2FLrELuYsD+EbFlGjtFBB8khScI6k5BKn98f36t1GzPdV3PPNdc9M+/n43E/5r6u+7ru+z1XM/Ppur7X9/s1d0dERGRDGqQdQERE8p+KhYiIZKViISIiWalYiIhIVioWIiKSlYqFiIhkpWIhIiJZqViIiEhWKhYiIpJVo7QDVJett97ad9hhhyrv/9VXX9G8efPqC1TNlC83ypcb5ctNPuebPXv2x+7eKuuG7l4nHkVFRZ6LsrKynPZPmvLlRvlyo3y5yed8wCyP8TdWl6FERCQrFQsREclKxUJERLJSsRARkaxULEREJCsVCxERyUrFAmDVqrQTiIjkNRWLd96Bjh35yaOPpp1ERCRvqVhMmgTLlrHrFVfAyJGwZk3aiURE8o6KxVlnwYQJrGvUCG64AQ4+GD78MO1UIiJ5RcUC4OSTmXvNNbDttvDUU1BcDLNnp51KRCRvqFhEVuy+O8yaBT16hHaMnj3hjjvSjiUikhdULDK1bg1lZXDyyfDttzBkCJxxhtoxRKTeU7Eor2lTmDABbroJGjeGa6+FPn3go4/STiYikhoVi8oMHx7OMgoLw9fiYnjppbRTiYikQsViQ3r2DA3d3brBsmVh+a670k4lIlLjVCyyadMGpk+HE0+Eb76BY48Nt9uuXZt2MhGRGqNiEccmm8DNN4d+GI0awVVXQd++8MknaScTEakRKhZxmcFpp8Hjj8M224SvxcUwb17ayUREEqdisbF69Qr9MYqLYenS0C9j0qS0U4mIJCrRYmFmfc3sNTNbbGbnVvB6LzObY2ZrzWxAudcuN7OXo8dRSebcaO3ahZ7eQ4aEdoxBg+Dss+G779JOJiKSiMSKhZk1BMYB/YBOwNFm1qncZsuA44G7yu17KNAF6Ax0B35vZpsnlbVKmjWD224L/TAaNoSxY6FfP7VjiEidlOSZRTdgsbsvcffVQCnQP3MDd1/q7vOBdeX27QRMd/e17v4VMA/om2DWqjGD3/wG/v1vaNUKHnssXJ6aOzftZCIi1SrJYtEGeCdjeXm0Lo55QD8z29TMtgZKgHbVnK/6HHhg6I+xvh1j333VH0NE6hRz92Te2Gwg0Mfdh0bLg4Fu7j6ygm0nAg+7+30Z684DBgIfAR8CL7r7teX2GwYMAygsLCwqLS2tct6VK1dSUFBQ5f0BGqxezc5XX822//oXAO8MHMiS4cPxhg1zet/qypck5cuN8uVG+aqupKRktrsXZ93Q3RN5AD2AqRnLo4BRlWw7ERiwgfe6CzhkQ59XVFTkuSgrK8tp//9at8593Dj3Ro3cwb2kxP3DD3N+22rLlxDly43y5Ub5qg6Y5TH+pid5GWomsLOZtTezJsAgYHKcHc2soZltFT3fE9gTmJZY0upkBqee+uNxpebMSTuZiEiVJVYs3H0tMAKYCiwC7nH3hWY2xswOBzCzrma2nHC5abyZLYx2bww8bWavABOA46L3qz322y+0Y+yzz/fjSml+DBGppRol+ebuPgWYUm7d6IznM4G2Fey3inBHVO3Wpg08+WS4Y2rChNAvY9asMFxI48ZppxMRiU09uJPWtCmMHx8ejRvD9deHeb4/+CDtZCIisalY1JRhw8Lotevn+S4qghdfTDuViEgsKhY1qUeP0I7Rsye8+y7svz/cemvaqUREslKxqGnbbgtPPBHumFq9Gk466fvnIiJ5SsUiDU2awLhx4ayiaVO48UY46CC1
Y4hI3lKxSNMJJ8DTT0PbtvDss6EdY9astFOJiPyIikXaunYNBWJ9O8Z++8Hf/552KhGRH1CxyAeFhaEdY/hw+PZb+PWv4Xe/0zzfIpI3VCzyRZMmcNNN4dGoEfzlL2F+jE8/TTuZiIiKRd4ZPjyMJ7XNNmGejK5dYcGCtFOJSD2nYpGP9tsvtGMUFcGSJdCjB1s/9VTaqUSkHlOxyFft2oU7pY47Dr76ip9eeCGMHg3ryk8qKCKSPBWLfNasWRip9sor8QYN4JJL4MgjYcWKtJOJSD2jYpHvzOB3v2P+n/8MLVrA5Mlh2PPXX087mYjUIyoWtcRnXbvCzJmw++6waBF06wbR9K0iIklTsahNdtwRnn8+XIr64gs45BC4/HJIaB51EZH1VCxqm802g/vug4svDkXi3HPhmGPg66/TTiYidZiKRW3UoEG4M+qBB6CgAEpLw3Ahb7+ddjIRqaMSLRZm1tfMXjOzxWZ2bgWv9zKzOWa21swGlHvtCjNbaGaLzOw6M7Mks9ZK/fvDCy/ATjvB3LlQXBymcRURqWaJFQszawiMA/oR5tM+2szKz6u9DDgeuKvcvvsCPYE9gZ8CXYEDkspaq3XqFGbc69sXPv44TNl6/fVqxxCRapXkmUU3YLG7L3H31UAp0D9zA3df6u7zgfI9zRzYBGgCNAUaA5rsoTItWsDDD8M558B338FvfgMnngirVqWdTETqCPOE/gcaXVbq6+5Do+XBQHd3H1HBthOBh939vox1VwJDAQNucPfzKthvGDAMoLCwsKi0tLTKeVeuXElBQUGV909a3HytnniCXa+4gobffsuKXXfl5TFjWN2qVd7kS4vy5Ub5cpPP+UpKSma7e3HWDd09kQcwELg5Y3kwcH0l204EBmQs7wQ8AhREj+eBXhv6vKKiIs9FWVlZTvsnbaPyvfSS+/bbu4N7YaH7M88kFeu/6tTxS4Hy5Ub5qg6Y5TH+pid5GWo50C5juS3wXsx9jwRmuPtKd18JPArsU8356q7OncNAhCUlYarWkhKYMCHtVCJSiyVZLGYCO5tZezNrAgwCJsfcdxlwgJk1MrPGhMbtRQnlrJu23hqmTYPTT4c1a8LQ56ecAqtXp51MRGqhxIqFu68FRgBTCX/o73H3hWY2xswOBzCzrma2nHDJaryZLYx2vw94E1gAzAPmuftDSWWtsxo1gmuugYkToWlTGD8eDjoI/vOftJOJSC3TKMk3d/cpwJRy60ZnPJ9JuDxVfr/vgOFJZqtXhgwJt9geeSQ8+2zoj3H//WFiJRGRGNSDu77o2hVmzw49vd99F/bfH26/Pe1UIlJLqFjUJ4WF8MQToe3i22/h+OO/b9MQEdkAFYv6pkkTuPHG0H7RuDFcdx306RN6f4uIVELFor4aNgzKysLZRllZaMeYMyftVCKSp1Qs6rOePUM7RrduYcTanj3DNK4iIuWoWNR3bdrAU0/B0KFhLKkhQ2DkSLVjiMgPqFhI6IPxt799345xww3Qu7f6Y4jIf6lYyPeGDYPp06F1a3j6aSgqghkz0k4lInlAxUJ+qEeP0I6x337w3nvQq5fGlRIRFQupwE9+Ao8/DiNGfD+u1LBhoW+GiNRLKhZSsSZNwox7EyfCJpuENo0DDoDly9NOJiIpULGQDRsyBJ55BrbbLsz3XVQU7p4SkXpFxUKyKyoK82McdBB8+GG4U+q66zTPt0g9omIh8bRqBVOnwllnwdq1YUypIUPgm2/STiYiNUDFQuJr1AjGjoW774ZNN4W//z30+l66NO1kIpIwFQvZeIMGwfPPQ4cO8NJLUFxMi1mz0k4lIglSsZCq2XPP0I7Rrx988gl7nn02XHoprFuXdjIRSUCixcLM+prZa2a22MzOreD1XmY2x8zWmtmAjPUlZjY347HKzI5IMqtUQYsW8NBDcMEFmDucfz4ccQR8/nnayUSkmm2wWJhZQzP7bVXe2MwaAuOAfkAn4Ggz61Rus2XA8cBdmSvdvczd
O7t7Z+Ag4GtgWlVySMIaNoQxY5h/2WWw5ZaheBQVwdy5aScTkWq0wWIRzYXdv4rv3Q1Y7O5L3H01UFr+vdx9qbvPBzZ07WIA8Ki7f13FHFIDPu3RI8yHsffesGRJGDZE07aK1BlxLkM9a2Y3mNn+ZtZl/SPGfm2AdzKWl0frNtYg4O4q7Cc1rX17ePZZOPHEMNz58ceHoUJWrUo7mYjkyDxLxyozK6tgtbv7QVn2Gwj0cfeh0fJgoJu7j6xg24nAw+5+X7n12wLzgdbu/qMJFsxsGDAMoLCwsKi0tHSD38uGrFy5koKCgirvn7Talu8njzzCLtdeS4M1a1jRsSMLL7qIb3/yk7zJl2+ULzfKV3UlJSWz3b0464bunsgD6AFMzVgeBYyqZNuJwIAK1p8OTIjzeUVFRZ6LsrKynPZPWq3MN2uW+w47uIN7y5bu//pXjedar1YevzyifLnJ53zALI/xNzbrZSgz28LM/mJms6LHVWa2RYyCNRPY2czam1kTwuWkyTH2y3Q0ugRVexUVheHODzkEPv003GY7ZoxurxWpheK0WdwKfAn8KnqsAG7LtpO7rwVGAFOBRcA97r7QzMaY2eEAZtbVzJYDA4HxZrZw/f5mtgPQDpi+Md+Q5JmWLcMdUmPGhOULL4TDDgvFQ0RqjUYxttnR3X+ZsXyxmcW6L9LdpwBTyq0bnfF8JtC2kn2XUrUGcck3DRrABRdAt25wzDHw6KPhrOOf/4Quce6VEJG0xTmz+MbM9lu/YGY9AY0eJxuvT59we21xcRhPat994ZZb0k4lIjHEKRanAOPMbKmZLQVuAIYnmkrqru23D/NjDB8eZt4bOhROOkmj14rkuWw9uBsAHd19L2BPYE9339tDRzqRqmnaFG666ftZ+G69NZxlLF6cdjIRqUS2HtzrCI3UuPsKd19RI6mkfhgyBGbMgB13DMODrG/HEJG8E+cy1GNmdpaZtTOzlusfiSeT+mGvvcLttb/8JaxYAQMGwBlnwOrVaScTkQxxisWJwGnAU8Ds6KHJC6T6bLEF3HtvmKq1cWO49lrYf394++20k4lIJE6bxXHu3r7co0MN5ZP6wgxGjgyN39tvDy++GAYlfPjhtJOJCPHaLK6soSwioS/GnDmh495nn8EvfgHnnANrfjQ0mIjUoDiXoaaZ2S/NzBJPIwKh1/eDD8IVV4T5Mq64Ag46CN59N+1kIvVWnGJxJnAvsNrMVpjZl2amu6IkWQ0awO9/D08+Ca1bh8tTnTvDNM2BJZKGrMXC3Tdz9wbu3tjdN4+WN6+JcCLstx+89BL87Gfw8cfQt28YX+q779JOJlKvxBl11szsODO7IFpuZ2bdko8mEtlmmzCe1PrBCMeMgZ//HD74IN1cIvVInMtQfyXMTXFMtLySMLe2SM1p2DAMRvjYY6F4PPFEuCz15JNpJxOpF+IUi+7ufhqwCsDdPwOaJJpKpDK9e4fe3gccAP/5T1i+7DLNkSGSsDjFYo2ZNQQcwMxaAfrNlPRsuy38+9/whz+EInHeeWGCpQ8/TDuZSJ0Vp1hcB9wPbGNmlwLPAJclmkokm0aN4NJLYcoU2GormDpVl6VEEhTnbqg7gbOBPwHvA0e4+71JBxOJpV+/cFlq//3h/ffDZamLLtLdUiLVLM6ZBe7+qruPc/cb3H1R3Dc3s75m9pqZLTazcyt4vZeZzTGztWY2oNxr25nZNDNbZGavRNOsivxY27ahwfv888EdLr44FA114hOpNrGKRVVE7RzjgH5AJ+BoM+tUbrNlwPHAXRW8xR3AWHffDegG6IK0VK5RI7jkknC3VGEhTJ8eLks9+mjayUTqhMSKBeEP/GJ3X+Luq4FSoH/mBu6+NJpI6QcN5lFRaeTuj0XbrXT3rxPMKnVF794wb973nfgOOQTOPhtbuzbtZCK1WqxiYWbbm9nB0fNmZrZZjN3aAO9kLC+P1sWxC/C5mf2fmb1kZmOjMxWR7AoL4V//gj/9KfTPGDuW
zqefHub9FpEqMXff8AZmJwPDgJbuvqOZ7Qzc5O69s+w3EOjj7kOj5cFAN3cfWcG2E4GH3f2+aHkAcAuwN+FS1SRgirvfUm6/YVE2CgsLi0pLS7N/x5VYuXIlBQUFVd4/acpXNZsvWECnP/6RTT78kDUFBbz2+9/zca9eacf6kXw9fuspX27yOV9JSclsdy/OuqG7b/ABzCV0wnspY92CGPv1AKZmLI8CRlWy7URgQMbyPsCTGcuDgXEb+ryioiLPRVlZWU77J035cvDJJ/7Rvvu6h+Zv9xEj3L/5Ju1UP5DXx8+VL1f5nA+Y5Vn+nrt7rMtQ33pocwDAzBoRddDLYiaws5m1N7MmwCBgcoz91u/bIuoACHAQ8ErMfUV+qGVLXv7jH+Gaa8JMfDfcAD16wOuvp51MpNaIUyymm9kfgGZm9jPCcOUPZdvJ3dcCI4CpwCLgHndfaGZjzOxwADPrambLgYHAeDNbGO37HXAW8LiZLQAM+NvGf3siETM4/XR47jno0CH0zSgqgjvvTDuZSK3QKMY25wInAQuA4cAU4OY4b+7uU6LtM9eNzng+E2hbyb6PAXvG+RyR2IqLw0x8w4fDpElw3HGhj8Z110Hz5mmnE8lbcXpwr3P3v7n7QHcfED2PcxlKJD9tsQXcfTdMmACbbAK33vp9ERGRCsWZz2KBmc0v93jazK42s61qIqRItTODk0+GF1+ETp3g1Vdhn33g8ss1VIhIBeK0WTwKPAIcGz0eAp4C/kO4i0mk9tpjD5g1C0aOhDVr4NxzQ8e+ZcvSTiaSV+IUi57uPsrdF0SP84AD3f1yYIdk44nUgGbNQpvFlCnfDxWy555wV0Wj0IjUT3GKRYGZdV+/EE2pur53icZQkLqjXz9YsAD694cvvoBjjw2Pzz9PO5lI6uIUi6HAzWb2lpktJdwJdbKZNScMWy5Sd7RqBfffHxq/N900nF3stVc42xCpx+LcDTXT3fcAOgOd3X1Pd3/R3b9y93uSjyhSw9Y3fs+dC127hvaLkhIYNQpWr86+v0gdFHcgwUMJfSx+Y2ajzWx0tn1Ear2dd4Znn4ULLggF5M9/Dj2/X3017WQiNS7OrbM3AUcBIwk9qQcC2yecSyQ/NG4MY8bAU09B+/ahL0aXLvDXv4aRpkTqiThnFvu6+6+Bz9z9YsIAge2SjSWSZ3r2DJelhgyBb76B006Dww6DDz5IO5lIjYhTLFZFX782s9bAGqB9cpFE8tTmm8PEiXDPPdCiRbjVdo894KGsQ6WJ1HpxisVDZrYlMBaYAywF7k4ylEheGzgw3GLbuzd89BEcfngYa+rLL9NOJpKYDRYLM2sAPO7un7v7PwltFbtmDgYoUi+1aQPTpsFf/gJNmoRbbffYIwxKKFIHbbBYuPs64KqM5W/d/YvEU4nUBg0awG9/C7Nnh0bvt98OZxsjRsBXX6WdTqRaxbkMNc3Mfmlmlngakdropz+FGTPgkkvC3VPjxoXhQp5+Ou1kItUmTrE4kzDh0WozW2FmX5rZioRzidQujRvD+efDzJmhx/eSJXDAAeHM4+uv004nkrM4Pbg3c/cG7t7Y3TePljeviXAitc5ee4Vhz0ePDpeprrkG9t4bnn8+7WQiOYnTKc/M7DgzuyBabhcNJpiVmfU1s9fMbLGZnVvB673MbI6ZrTWzAeVe+87M5kaPuHN3i6SvSRO4+GJ44QXYffcw1/d++8HZZ8OqVdn3F8lDcS5D/ZXQEe+YaHklMC7bTmbWMNquH9AJONrMOpXbbBlwPFDRWNDfuHvn6HF4jJwi+aWoKDR+jxoVlseODWcZL76Ybi6RKohTLLq7+2lEnfPc/TOgSYz9ugGL3X2Ju68GSoH+mRu4+1J3nw+s27jYIrVE06Zw2WXhMtSuu4ZxpXr0gD/8Ab79Nu10IrHFKRZrorMEBzCzVsT7494GeCdjeXm0Lq5NzGyWmc0wsyM2Yj+R/NOtWxhX6qyzwphSf/qT5v2WWsU8y2Bo
ZnYsYSDBLsDtwADgfHe/N8t+A4E+7j40Wh4MdHP3kRVsOxF42N3vy1jX2t3fM7MOwBNAb3d/s9x+w4BhAIWFhUWlpaVZvt3KrVy5koKCguwbpkT5cpNP+TZ/+WV2vfxyNl2+HG/QgLePO46FRxxB8xYt0o5WqXw6fhVRvqorKSmZ7e7FWTd096wPYFfgNGAEsFvMfXoAUzOWRwGjKtl2IjBgA++1wdfdnaKiIs9FWVlZTvsnTflyk3f5vvrK/Ywz3M3cwb/s0MF91qy0U1Uq745fOcpXdcAsj/E3Pc7dUNcCLd19nLvf4O6LYhasmcDOZtbezJoAg4BYdzWZWQszaxo93xroCbwS83NF8t+mm8LVV8OTT0KHDhQsWQLdu8M554RRbUXyTJw2iznA+dHtr2PNLPvpCuDuawlnIlOBRcA97r7QzMaY2eEAZtbVzJYT5sgYb2YLo913A2aZ2TygDPizu6tYSN3TqxfMn887AweGtowrrgh9NZ56Ku1kIj8Qp1Pe7e5+COHupteBy83sjThv7u5T3H0Xd9/R3S+N1o1298nR85nu3tbdm7v7Vu6+e7T+OXffw933ir7eUuXvUCTfNW/Om6eeCs89F/plvPFG6P196qmwQoMlSH6INa1qZCdC28UOgOaVFKlu3buHu6MuvDAMH3LjjWHcqSlT0k4mEqvNYv2ZxBhgIVDk7r9IPJlIfdSkCVx0UejM17UrvPMOHHooDB4MH3+cdjqpx+KcWbwF9HD3vu5+q7t/nnQokXpvjz1CR74rr4RmzeAf/4BOncIsfZr7W1IQp83iJuA7M+sWjeXUy8x61UA2kfqtYUP43e9g/nw48MAwK99RR8GRR8J776WdTuqZOJehhgJPEe5qujj6elGysUTkv3baCR5/HMaPD/OAP/hgOMu45RadZUiNiXMZ6nSgK/C2u5cAewMfJZpKRH6oQQMYNgwWLoTDDoMvvoChQ+Hgg8PcGSIJi1MsVrn7KgAza+rurwIdk40lIhVq2xYmT4a77oKttw5zfu+xB1x1Faxdm3Y6qcPiFIvlZrYl8ADwmJk9COiCqUhazODoo+GVV+CYY8JMfGedFYZEf/bZtNNJHRWngftId//c3S8CLgBuATQKrEjaWrWCO++ERx6B9u1DQ/h++8FJJ+k2W6l2G9MpD3ef7u6TPcxPISL54JBDQlvGBReEfhq33godO8LNN8M6TRUj1WOjioWI5KlmzWDMmHB2cfDB8OmncPLJ4Uxj3ry000kdoGIhUpd07AjTpkFpKWy7bejY16UL/Pa3GmdKcqJiIVLXmIXOe6++CqefHtZdc02Y1nXSJPXNkCpRsRCpqzbfPBSJ2bNhn33g/fdh0CDo0wdefz3tdFLLqFiI1HWdO4dbaidMgBYt4LHHQt+M0aM10ZLEpmIhUh80aBAavF97DU48EVavhksuCUOgP/po2umkFlCxEKlPWrUKY0o980w4u1iyJNx6e+SR8NZbaaeTPKZiIVIf9ewZ2jKuugoKCuCBB2C33cKlqa+/Tjud5KFEi4WZ9TWz16L5u8+t4PVeZjbHzNaa2YAKXt/czN41sxuSzClSLzVuDGeeGe6aOvZY+PbbcGlq113h3nt115T8QGLFwswaAuOAfkAn4Ggz61Rus2XA8cBdlbzNJcD0pDKKCNCmTZhc6ZlnYO+9w+x8v/oVHHQQLFiQdjrJE0meWXQDFrv7kmh4kFKgf+YG7r7U3ecDPxqTwMyKgEJgWoIZRWS9nj1h5swwb8ZWW8GTT4Y7qUaOpJE69NV75gmdakaXlfq6+9BoeTDQ3d1HVLDtROBhd78vWm4APAEMBnoDxZXsNwwYBlBYWFhUWlpa5bwrV66koKCgyvsnTflyo3wbp9GXX7LDbbfR5sEHsXXrWL3ZZrw1dCjvH3pomMEvz+Tb8Ssvn/OVlJTMdvfirBu6eyIPYCBwc8byYOD6SradCAzIWB4BnB09Px64IdvnFRUVeS7Kyspy2j9pypcb
5aui+fPdDzzQPbRguHfu7P7002mn+pG8PX6RfM4HzPIYf9OTvAy1HGiXsdyW+PNg9ABGmNlS4Erg12b25+qNJyJZ7bEHPPEECy+8ENq1g7lzYf/9Q4P4u++mnU5qUJLFYiaws5m1N7MmwCBgcpwd3f1Yd9/O3XcAzgLucPcf3U0lIjXAjI8OPDDcNTV6NDRtGmbq69gR/vSncBeV1HmJFQt3X0u4nDQVWATc4+4LzWyMmR0OYGZdzWw54ZLVeDNbmFQeEcnRppvCxRfDokXwP/8DX30Ff/gD7L47PPigbrWt4xLtZ+HuU9x9F3ff0d0vjdaNdvfJ0fOZ7t7W3Zu7+1buvnsF7zHRK2jcFpGUtG8P//xnGGNqt93gzTfhiCPgwAPD3VRSJ6kHt4hUzcEHh4mVrr0WWraEp56Cbt3CyLZLlqSdTqqZioWIVF3jxvCb34Szi7PPDu0ZkyaFXuBnngmffJJ2QqkmKhYikrstt4TLLw/zZAweDGvXwtVXw047wdixsGpV2gklRyoWIlJ9ttsO7rgjDFLYuzd8/nk449h1V7jzTlj3o8EapJZQsRCR6rf33qEB/NFHQ1+Nt9+G446Drl3hiSfSTidVoGIhIskwg7594aWX4NZboXVrmDMnnHEccgi8/HLaCWUjqFiISLIaNoQTToA33oA//hE22yyccey1FwwdCu/FHdhB0qRiISI1Y9NN4bzzYPFiOO20MNXrLbeERvDzzw/tG5K3VCxEpGZtsw3ccAMsXBh6gn/zDVx6KXToEIYP+eqrtBNKBVQsRCQdu+wSeoI/+ywccAB89lkYPmTHHeG66zTmVJ5RsRCRdO27L5SVwbRp4W6pDz6A008PxeSWW0KfDUmdioWIpM8MfvYzeOEFeOAB+OlPYdmy0ADeqROUlqqPRspULEQkf5hB//5hzKm77gqN32+8AUcfHfpuTJ6s0W1TomIhIvmnQYNQIF55Bf72N2jbFubPD4WkRw94/PG0E9Y7KhYikr8aNw6Xot54A665JtxJ9cILYcTbgw6C559PO2G9oWIhIvlvk01Co/ebb8Jll4WBC8vKQuP4YYdR8MYbaSes81QsRKT2KCiAUaPgrbdCB7/mzeGRRygeNgwOPxxefDHthHVWosUu5DBkAAANx0lEQVTCzPqa2WtmttjMfjSHtpn1MrM5ZrbWzAZkrN/ezGab2VwzW2hmpySZU0RqmS23DEOHLFkCZ57Jd02bwkMPQffu0KcPPPNM2gnrnMSKhZk1BMYB/YBOwNFm1qncZsuA44G7yq1/H9jX3TsD3YFzzax1UllFpJbaZhu46ipm3H03nHNOOPOYNg323z9M8/r447p7qpokeWbRDVjs7kvcfTVQCvTP3MDdl7r7fGBdufWr3X19982mCecUkVpuTYsW8Oc/h6HQR4+GLbaA6dNDQ3jPnjBliopGjpL8I9wGeCdjeXm0LhYza2dm86P3uNzdNTSliGxYy5Zw8cWhaFx6KWy1Vbhj6tBDobgY7r9fnfuqyDyhamtmA4E+7j40Wh4MdHP3kRVsOxF42N3vq+C11sADwC/c/YNyrw0DhgEUFhYWlZaWVjnvypUrKSgoqPL+SVO+3ChfbmprvobffMO2Dz3EdqWlNPnss7Bt+/a8fdxxfHTAAWH49BTz5YOSkpLZ7l6cdUN3T+QB9ACmZiyPAkZVsu1EYMAG3uu2Db3u7hQVFXkuysrKcto/acqXG+XLTa3P9/XX7tdf7962rXu4IOXesaP77be7r1mTfr4UAbM8xt/0JC9DzQR2NrP2ZtYEGARMjrOjmbU1s2bR8xZAT+C1xJKKSN3WrBmMGBHm0hg/HnbYAV57DYYMCQMWjh8fhkqXSiVWLNx9LTACmAosAu5x94VmNsbMDgcws65mthwYCIw3s4XR7rsBL5jZPGA6cKW7L0gqq4jUE02bwrBh8PrrMHFiKBRvvQWnnALbbx/aOz76KO2UeSnRu4zcfYq77+LuO7r7pdG60e4+
OXo+093buntzd9/K3XeP1j/m7nu6+17R1wlJ5hSReqZx43BW8corYUTboqJQJC66CLbbDoYPD2ce8l+6JVVE6q+GDeGoo2DmTHjySfjFL2DVKpgwAXbdNSxPn67bblGxEBEJQ6MfcEAYAv3VV8OZxSabwMMPh859XbvC3XfDmjVpJ02NioWISKaOHeGmm8LkSxddBK1awezZcMwxYcrXq66CFSvSTlnjVCxERCrSqhVceGHo4Lf+stQ778BZZ0G7duHrO+9kf586QsVCRGRDmjWDk0+GhQvDYIUHHhjOLK66Ctq3D2ccs2ennTJxKhYiInE0aACHHRbm0Zg1KxQJCG0ZxcVhDKpJk+psu4aKhYjIxioqgjvvDEOk/+53YeDC556DQYNCh79LL61z/TVULEREqmq77eDKK2H5crjxRthtN3jvPTj//NCuccIJMGdO2imrhYqFiEiuCgpCL/CFC+Hf/w6z9q1eHXqJFxWx98iRcM89tfoSlYqFiEh1MYPeveHBB8M4VGeeCVtswRYvvxw6/7VvH+YQr4WXqFQsRESS0KFDuGNq+XJeP+OMcOvtu++GucPbtYMTT4SXXko7ZWwqFiIiSSoo4L3+/cM4VNOmhTuqVq+G226DLl3CFLCTJoV1eUzFQkSkJpjBz34W+mq8/jqccQZsvjk880y4i6ptWzj77PBaHlKxEBGpaTvtBFdfHe6iGjcO9tgjtGOMHRuGGykpCf03Vq1KO+l/qViIiKRls83g1FNh3jyYMSO0Y2y6aRgB95hjwtnGmWfCokVpJ1WxEBFJnRl07w633BL6afz1r9C5M3zySTgD6dQJevWCf/wjtRn9VCxERPLJFlvA//5v6Mw3c2YYl6p5c3j6aRg8GFq3htNPh5dfrtFYiRYLM+trZq+Z2WIzO7eC13uZ2RwzW2tmAzLWdzaz581soZnNN7OjkswpIpJ3zMKYUxMmwPvvh6/FxfD553DddaGdo2fP0PHv668Tj5NYsTCzhsA4oB/QCTjazDqV22wZcDxwV7n1XwO/jqZZ7QtcY2ZbJpVVRCSvbbZZOMOYOTOMcHvKKWHdc8+FIUVatw6TNiUoyTOLbsBid1/i7quBUqB/5gbuvtTd5wPryq1/3d3fiJ6/B3wItEowq4hI7dClSxiH6r33QhtH9+7QsiXsskuiH5tksWgDZM4Msjxat1HMrBvQBHizmnKJiNR+BQXh7qkZM8KQ6Q2SbYI2T2gicjMbCPRx96HR8mCgm7uPrGDbicDD7n5fufXbAk8CQ9x9RgX7DQOGARQWFhaVlpZWOe/KlSspKCio8v5JU77cKF9ulC83+ZyvpKRktrsXZ93Q3RN5AD2AqRnLo4BRlWw7ERhQbt3mwBxgYJzPKyoq8lyUlZXltH/SlC83ypcb5ctNPucDZnmMv7FJnrfMBHY2s/Zm1gQYBEyOs2O0/f3AHe5+b4IZRUQkhsSKhbuvBUYAU4FFwD3uvtDMxpjZ4QBm1tXMlgMDgfFmtjDa/VdAL+B4M5sbPTonlVVERDasUZJv7u5TgCnl1o3OeD4TaFvBfv8A/pFkNhERiU89uEVEJCsVCxERyUrFQkREskqsn0VNM7OPgLdzeIutgY+rKU4SlC83ypcb5ctNPufb3t2zjpBRZ4pFrsxslsfpmJIS5cuN8uVG+XKT7/ni0GUoERHJSsVCRESyUrH43oS0A2ShfLlRvtwoX27yPV9WarMQEZGsdGYhIiJZ1atiEWOa16ZmNil6/QUz26EGs7UzszIzWxRNJ3t6BdscaGZfZIyXNbqi90o451IzWxB9/qwKXjczuy46hvPNrEsNZuuYcWzmmtkKMzuj3DY1egzN7FYz+9DMXs5Y19LMHjOzN6KvLSrZd0i0zRtmNqQG8401s1ejf7/7K5ulMtvPQoL5LjKzdzP+DQ+pZN8N/r4nmG9SRralZja3kn0TP37VKs7QtHXhATQkTKDUgTCZ0jygU7ltTgVuip4PAibVYL5tgS7R
882A1yvIdyBh3o80j+NSYOsNvH4I8ChgwD7ACyn+e/+HcA95aseQMCBmF+DljHVXAOdGz88FLq9gv5bAkuhri+h5ixrK93OgUfT88oryxflZSDDfRcBZMf79N/j7nlS+cq9fBYxO6/hV56M+nVlkneY1Wr49en4f0NvMrCbCufv77j4nev4lYaTejZ5ZMA/0Jwwt7x4mrNoymsSqpvUG3nT3XDpq5szdnwI+Lbc68+fsduCICnbtAzzm7p+6+2fAY4T56BPP5+7TPIwaDTCDCgb7rCmVHL844vy+52xD+aK/Hb8C7q7uz01DfSoWcaZ5/e820S/LF8BWNZIuQ3T5a2/ghQpe7mFm88zsUTPbvUaDBQ5MM7PZ0UyF5VXLdLrVYBCV/5KmfQwL3f19CP9JALapYJt8OY4nEs4UK5LtZyFJI6LLZLdWchkvH47f/sAH7v5GJa+nefw2Wn0qFhWdIZS/FSzONokyswLgn8AZ7r6i3MtzCJdV9gKuBx6oyWyRnu7eBegHnGZmvcq9ng/HsAlwOFDRxFn5cAzjyIfjeB6wFrizkk2y/Swk5UZgR6Az8D7hUk95qR8/4Gg2fFaR1vGrkvpULJYD7TKW2wLvVbaNmTUCtqBqp8BVYmaNCYXiTnf/v/Kvu/sKd18ZPZ8CNDazrWsqX/S570VfPyTMZtit3CZxjnPS+gFz3P2D8i/kwzEEPlh/aS76+mEF26R6HKMG9cOAYz26wF5ejJ+FRLj7B+7+nbuvA/5WyeemffwaAf8DTKpsm7SOX1XVp2IRZ5rXycD6u04GAE9U9otS3aLrm7cAi9z9L5Vs85P1bShm1o3w7/dJTeSLPrO5mW22/jmhIfTlcptNBn4d3RW1D/DF+ksuNajS/9GlfQwjmT9nQ4AHK9hmKvBzM2sRXWb5ebQucWbWFzgHONzdv65kmzg/C0nly2wDO7KSz63ytM7V5GDgVXdfXtGLaR6/Kku7hb0mH4Q7dV4n3CVxXrRuDOGXAmATwqWLxcCLQIcazLYf4TR5PjA3ehwCnAKcEm0zAlhIuLNjBrBvDR+/DtFnz4tyrD+GmRkNGBcd4wVAcQ1n3JTwx3+LjHWpHUNC0XofWEP43+5JhHawx4E3oq8to22LgZsz9j0x+llcDJxQg/kWE673r/85XH+HYGtgyoZ+Fmoo39+jn635hAKwbfl80fKPft9rIl+0fuL6n7mMbWv8+FXnQz24RUQkq/p0GUpERKpIxUJERLJSsRARkaxULEREJCsVCxERyUrFQiQPRKPhPpx2DpHKqFiIiEhWKhYiG8HMjjOzF6M5CMabWUMzW2lmV5nZHDN73MxaRdt2NrMZGfNCtIjW72Rm/44GM5xjZjtGb19gZvdFc0ncWVMjHovEoWIhEpOZ7QYcRRgArjPwHXAs0JwwFlUXYDpwYbTLHcA57r4nocfx+vV3AuM8DGa4L6EHMISRhs8AOhF6+PZM/JsSialR2gFEapHeQBEwM/pPfzPCIIDr+H7AuH8A/2dmWwBbuvv0aP3twL3ReEBt3P1+AHdfBRC934sejSUUza62A/BM8t+WSHYqFiLxGXC7u4/6wUqzC8ptt6ExdDZ0aenbjOffod9PySO6DCUS3+PAADPbBv47l/b2hN+jAdE2xwDPuPsXwGdmtn+0fjAw3cMcJcvN7IjoPZqa2aY1+l2IVIH+5yISk7u/YmbnE2Y3a0AYafQ04CtgdzObTZhd8aholyHATVExWAKcEK0fDIw3szHRewyswW9DpEo06qxIjsxspbsXpJ1DJEm6DCUiIlnpzEJERLLSmYWIiGSlYiEiIlmpWIiISFYqFiIikpWKhYiIZKViISIiWf0/nDL3AjYUKwIAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "num_train = train_data.size()[0]\n",
    "idx = np.arange(num_train)\n",
    "# Dedicated, seeded generator so the per-epoch sample order is reproducible\n",
    "shuffle_rng = np.random.RandomState(0)\n",
    "avg_loss_list = list()\n",
    "epoch_list = list()\n",
    "for epoch in range(num_epoch):\n",
    "    total_loss = 0.0\n",
    "    shuffle_rng.shuffle(idx)\n",
    "    # One optimizer step per sample, visiting the samples in shuffled order\n",
    "    for sample_id in idx:\n",
    "        y_pred = model(train_data[sample_id, :NUM_FEATURE])\n",
    "        y = train_data[sample_id, NUM_FEATURE:]\n",
    "        loss = loss_fn(y_pred, y)\n",
    "        # .item() yields a plain Python float. Adding the tensor itself would chain\n",
    "        # every sample's autograd history into total_loss and leave tensors that\n",
    "        # require grad in the list handed to matplotlib.\n",
    "        total_loss += loss.item()\n",
    "\n",
    "        # Before the backward pass, use the optimizer object to zero all of the\n",
    "        # gradients for the Tensors it will update (which are the learnable weights\n",
    "        # of the model)\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Backward pass: compute gradient of the loss with respect to model parameters\n",
    "        loss.backward()\n",
    "\n",
    "        # Calling the step function on an Optimizer makes an update to its parameters\n",
    "        optimizer.step()\n",
    "\n",
    "    # Mean per-sample loss for this epoch\n",
    "    avg_loss = total_loss / num_train\n",
    "    avg_loss_list.append(avg_loss)\n",
    "    epoch_list.append(epoch)\n",
    "\n",
    "# Plot loss\n",
    "plt.plot(epoch_list, avg_loss_list, 'r-', lw=2)\n",
    "plt.xlabel(\"epoch\")\n",
    "plt.ylabel(\"average error\")\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test model on unseen data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_class = list()\n",
    "actual_class = list()\n",
    "# Inference only: no_grad() avoids building autograd graphs for these forward passes\n",
    "with torch.no_grad():\n",
    "    for i in range(test_data.shape[0]):\n",
    "        # Samples are fed one at a time, the same way the training cell feeds them\n",
    "        y_pred_test = model(test_data[i, :NUM_FEATURE])\n",
    "        # torch.max(t, 0) returns (value, index); the index of the largest entry is\n",
    "        # the class. .item() stores it as a plain Python int instead of a 0-dim tensor.\n",
    "        pred_class.append(torch.max(y_pred_test, 0)[1].item())\n",
    "        actual_class.append(torch.max(test_data[i, NUM_FEATURE:], 0)[1].item())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Confusion matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Confusion matrix:\n",
      " [[10.  2.  0.]\n",
      " [ 0.  5.  0.]\n",
      " [ 0.  3. 10.]]\n"
     ]
    }
   ],
   "source": [
    "# Tally (predicted, actual) pairs: row = predicted class, column = actual class\n",
    "predicted_idx = np.array([int(p) for p in pred_class], dtype=int)\n",
    "actual_idx = np.array([int(a) for a in actual_class], dtype=int)\n",
    "conf_mat = np.zeros((NUM_CLASS, NUM_CLASS))\n",
    "# np.add.at adds 1 per pair, also when the same cell is hit more than once\n",
    "np.add.at(conf_mat, (predicted_idx, actual_idx), 1)\n",
    "print(\"Confusion matrix:\\n\", conf_mat)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
